{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
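    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)\n",
    "\n",
    "Variational autoencoder for short IMDB review snippets: a GRU encoder maps each snippet to a diagonal-Gaussian latent code, and a causal dilated 1-D CNN decoder reconstructs the snippet from that code (teacher forcing during training, greedy decoding at prediction time).\n",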
    "\n",
    "<img src=\"cnn_vae.png\" width=\"300\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters shared by every cell below. build_vocab() later adds the\n",
    "# 'word2idx' and 'idx2word' lookup tables to this same dict.\n",
    "PARAMS = dict(\n",
    "    # data\n",
    "    max_len=15,         # length every review window is padded / truncated to\n",
    "    vocab_size=10000,   # num_words passed to imdb.load_data\n",
    "    # model\n",
    "    embed_dims=128,     # must equal cnn_size: the decoder scores words against the embedding table\n",
    "    rnn_size=128,       # GRU encoder units\n",
    "    cnn_size=128,       # channels of every decoder conv layer\n",
    "    latent_size=16,     # dimensionality of the latent code z\n",
    "    kernel_sz=3,        # width of the decoder conv kernels\n",
    "    n_hidden_layer=3,   # number of dilated conv layers (dilation 1, 2, 4, ...)\n",
    "    # optimisation\n",
    "    clip_norm=5.0,      # global-norm gradient clipping threshold\n",
    "    anneal_max=1.0,     # final weight of the KL term\n",
    "    anneal_bias=6000,   # the KL weight reaches anneal_max / 2 at step anneal_bias / 2\n",
    "    batch_size=128,\n",
    "    n_epochs=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vocab(index_from=4):\n",
    "    \"\"\"Build the IMDB word <-> id tables and store them in PARAMS.\n",
    "\n",
    "    Raw Keras word ids are shifted by `index_from` (use the same offset as\n",
    "    load_data) so the low ids are free for the special tokens\n",
    "    <pad>=0, <start>=1, <unk>=2 and <end>=3.\n",
    "\n",
    "    Returns:\n",
    "        dict: the word -> id mapping, also stored as PARAMS['word2idx'].\n",
    "        The inverse mapping is stored as PARAMS['idx2word'].\n",
    "    \"\"\"\n",
    "    word2idx = {w: i + index_from\n",
    "                for w, i in tf.keras.datasets.imdb.get_word_index().items()}\n",
    "    word2idx.update({'<pad>': 0, '<start>': 1, '<unk>': 2, '<end>': 3})\n",
    "    PARAMS['word2idx'] = word2idx\n",
    "    PARAMS['idx2word'] = {i: w for w, i in word2idx.items()}\n",
    "    return word2idx\n",
    "\n",
    "    \n",
    "def load_data(index_from=4):\n",
    "    \"\"\"Load the IMDB train and test reviews as sequences of word ids.\n",
    "\n",
    "    Sentiment labels are discarded. The vocabulary is capped with\n",
    "    num_words=PARAMS['vocab_size'], and `index_from` must match build_vocab().\n",
    "    \"\"\"\n",
    "    (X_train, _), (X_test, _) = tf.keras.datasets.imdb.load_data(\n",
    "        num_words=PARAMS['vocab_size'], index_from=index_from)\n",
    "    return (X_train, X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "word2idx = build_vocab()\n",
    "reviews = np.concatenate(load_data())\n",
    "\n",
    "\n",
    "def make_windows(seqs, truncating):\n",
    "    \"\"\"Pad / truncate reviews to PARAMS['max_len'] ids that always start with <start>.\n",
    "\n",
    "    imdb.load_data prepends <start> to every review, and truncating='pre' would\n",
    "    cut it off. So it is stripped first, the remaining words are padded or\n",
    "    truncated to max_len - 1 ids, and <start> is put back in front.\n",
    "    \"\"\"\n",
    "    body = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        [s[1:] for s in seqs], PARAMS['max_len'] - 1,\n",
    "        truncating=truncating, padding='post')\n",
    "    sos = np.full([len(body), 1], PARAMS['word2idx']['<start>'], dtype=body.dtype)\n",
    "    return np.concatenate([sos, body], 1)\n",
    "\n",
    "\n",
    "# Two windows per review: its first and its last max_len - 1 words.\n",
    "X = np.concatenate((make_windows(reviews, 'post'), make_windows(reviews, 'pre')))\n",
    "\n",
    "enc_inp = X[:, 1:]  # encoder input: the words without <start>\n",
    "dec_inp = X         # teacher-forced decoder input, starting with <start>\n",
    "# decoder target: dec_inp shifted left by one step, with <end> appended\n",
    "dec_out = np.concatenate([X[:, 1:], np.full([X.shape[0], 1], PARAMS['word2idx']['<end>'])], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kl_w_fn(global_step):\n",
    "    \"\"\"Sigmoid-annealed weight of the KL term at training step `global_step`.\n",
    "\n",
    "    The weight rises from about 0 towards PARAMS['anneal_max']; it equals\n",
    "    anneal_max / 2 at step anneal_bias / 2, with slope set by 10 / anneal_bias.\n",
    "    \"\"\"\n",
    "    midpoint = tf.constant(PARAMS['anneal_bias'] / 2)\n",
    "    steepness = 10 / PARAMS['anneal_bias']\n",
    "    offset = tf.to_float(global_step) - midpoint\n",
    "    return PARAMS['anneal_max'] * tf.sigmoid(steepness * offset)\n",
    "\n",
    "\n",
    "def clip_grads(loss):\n",
    "    \"\"\"Return (gradient, variable) pairs for `loss`, clipped by global norm.\n",
    "\n",
    "    Covers every trainable variable; the global-norm threshold is PARAMS['clip_norm'].\n",
    "    The result is a zip iterator meant to be passed to Optimizer.apply_gradients.\n",
    "    \"\"\"\n",
    "    params = tf.trainable_variables()\n",
    "    clipped, _global_norm = tf.clip_by_global_norm(\n",
    "        tf.gradients(loss, params), PARAMS['clip_norm'])\n",
    "    return zip(clipped, params)\n",
    "\n",
    "\n",
    "def rnn_cell():\n",
    "    \"\"\"GRU cell with PARAMS['rnn_size'] units and orthogonally initialised kernels.\"\"\"\n",
    "    init = tf.orthogonal_initializer()\n",
    "    return tf.nn.rnn_cell.GRUCell(num_units=PARAMS['rnn_size'], kernel_initializer=init)\n",
    "\n",
    "\n",
    "def cnn_block(x, dilation_rate, pad_sz):\n",
    "    \"\"\"Causal dilated 1-D convolution followed by ReLU.\n",
    "\n",
    "    x is left-padded with `pad_sz` zeros along the time axis (axis 1). With\n",
    "    pad_sz = (kernel_sz - 1) * dilation_rate, as cnn_forward passes, the 'valid'\n",
    "    convolution keeps the sequence length and output step t only sees inputs <= t.\n",
    "    \"\"\"\n",
    "    # Pad on the left only: right padding would just create outputs past the\n",
    "    # end of the sequence that have to be sliced off again (and slicing with\n",
    "    # [:-pad_sz] returns an empty tensor when pad_sz == 0).\n",
    "    x = tf.pad(x, [[0, 0], [pad_sz, 0], [0, 0]])\n",
    "    x = tf.layers.conv1d(inputs=x,\n",
    "                         filters=PARAMS['cnn_size'],\n",
    "                         kernel_size=PARAMS['kernel_sz'],\n",
    "                         dilation_rate=dilation_rate)\n",
    "    return tf.nn.relu(x)\n",
    "\n",
    "\n",
    "def cnn_forward(x, embedding):\n",
    "    \"\"\"Residual stack of causal dilated convolutions plus a tied output projection.\n",
    "\n",
    "    Args:\n",
    "        x: decoder features of shape (batch, time, cnn_size).\n",
    "        embedding: word-embedding matrix with PARAMS['vocab_size'] rows and\n",
    "            PARAMS['cnn_size'] columns, reused to score every vocabulary word.\n",
    "\n",
    "    Returns:\n",
    "        Logits of shape (batch, time, vocab_size).\n",
    "    \"\"\"\n",
    "    for i in range(PARAMS['n_hidden_layer']):\n",
    "        dilation_rate = 2 ** i  # dilations 1, 2, 4, ...\n",
    "        pad_sz = (PARAMS['kernel_sz'] - 1) * dilation_rate\n",
    "        x += cnn_block(x, dilation_rate, pad_sz)  # residual connection\n",
    "    # Score the vocabulary once, from the output of the last layer.\n",
    "    flat = tf.reshape(x, [-1, PARAMS['cnn_size']])\n",
    "    logits = tf.matmul(flat, embedding, transpose_b=True)\n",
    "    return tf.reshape(logits, [tf.shape(x)[0], -1, PARAMS['vocab_size']])\n",
    "\n",
    "\n",
    "def autoregressive(embedding, z, input_proj):\n",
    "    \"\"\"Greedily decode PARAMS['max_len'] token ids conditioned on `z`.\n",
    "\n",
    "    Each tf.while_loop step re-runs the whole CNN decoder on <start> followed by\n",
    "    the ids generated so far, and keeps the argmax id at position i.\n",
    "\n",
    "    Args:\n",
    "        embedding: embedding table used to embed the previous ids and, inside\n",
    "            cnn_forward, to score the vocabulary.\n",
    "        z: latent codes, concatenated with the embedded ids on the last axis;\n",
    "            presumably (batch, max_len, latent_size) - TODO confirm.\n",
    "        input_proj: callable applied to [embedded ids, z] to build decoder inputs.\n",
    "\n",
    "    Returns:\n",
    "        int32 tensor of shape (batch, max_len) holding the generated ids.\n",
    "    \"\"\"\n",
    "    batch_sz = tf.shape(z)[0]\n",
    "    \n",
    "    def cond(i, x, temp):\n",
    "        return i < PARAMS['max_len']\n",
    "\n",
    "    # Loop state after i steps: x = generated ids left-aligned and zero-padded\n",
    "    # ([id_0, ..., id_{i-1}, 0, ...]); temp = the same ids right-aligned\n",
    "    # ([0, ..., 0, id_0, ..., id_{i-1}]).\n",
    "    def body(i, x, temp):\n",
    "        # Decoder input: <start> followed by x shifted right by one position.\n",
    "        sos = tf.fill([batch_sz, 1], PARAMS['word2idx']['<start>'])\n",
    "        x = tf.concat([sos, x[:, :-1]], 1)\n",
    "        \n",
    "        x = tf.nn.embedding_lookup(embedding, x)\n",
    "        x = input_proj(tf.concat([x, z], -1))\n",
    "        logits = cnn_forward(x, embedding)\n",
    "        # Only position i is new; earlier positions were already decided.\n",
    "        ids = tf.argmax(logits, -1, output_type=tf.int32)[:, i]\n",
    "        ids = tf.expand_dims(ids, -1)\n",
    "\n",
    "        # Append the new id on the right of temp, dropping its leftmost (still zero) entry.\n",
    "        temp = tf.concat([temp[:, 1:], ids], -1)\n",
    "\n",
    "        # Rotate temp so the i+1 generated ids come first: this rebuilds x.\n",
    "        x = tf.concat([temp[:, -(i+1):], temp[:, :-(i+1)]], -1)\n",
    "        # The dynamic slices lose the static length; presumably reshaped so the\n",
    "        # loop variable keeps a fixed shape across iterations.\n",
    "        x = tf.reshape(x, [batch_sz, PARAMS['max_len']])\n",
    "        i += 1\n",
    "        return i, x, temp\n",
    "    \n",
    "    x = tf.zeros([batch_sz, PARAMS['max_len']], tf.int32)\n",
    "    _, res, _ = tf.while_loop(cond, body, [tf.constant(0), x, x])\n",
    "    \n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(inputs, labels, mode):\n",
    "    \"\"\"Build the VAE graph: GRU encoder, diagonal-Gaussian posterior, CNN decoder.\n",
    "\n",
    "    Args:\n",
    "        inputs: encoder token ids, zero-padded (non-zero ids give the lengths).\n",
    "        labels: dict with 'dec_inp' (teacher-forced decoder ids); read only in TRAIN.\n",
    "        mode: a tf.estimator.ModeKeys value.\n",
    "\n",
    "    Returns:\n",
    "        In TRAIN mode (logits, posterior, prior); otherwise the int32 ids produced\n",
    "        by greedy decoding from a posterior sample.\n",
    "    \"\"\"\n",
    "    is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
    "    enc_seq_len = tf.count_nonzero(inputs, 1, dtype=tf.int32)\n",
    "\n",
    "    with tf.variable_scope('Encoder'):\n",
    "        embedding = tf.get_variable('lookup_table', [PARAMS['vocab_size'],\n",
    "                                                     PARAMS['embed_dims']])\n",
    "        x = tf.nn.embedding_lookup(embedding, inputs)\n",
    "\n",
    "        _, enc_state = tf.nn.dynamic_rnn(rnn_cell(), x, enc_seq_len, dtype=tf.float32)\n",
    "\n",
    "        z_mean = tf.layers.dense(enc_state, PARAMS['latent_size'])\n",
    "        # MultivariateNormalDiag takes a standard deviation (scale_diag), which must\n",
    "        # be positive; a raw dense output can be negative or 0 (log 0 in the KL).\n",
    "        z_scale = tf.nn.softplus(tf.layers.dense(enc_state, PARAMS['latent_size'])) + 1e-6\n",
    "\n",
    "    posterior = tf.contrib.distributions.MultivariateNormalDiag(z_mean, z_scale)\n",
    "    prior = tf.contrib.distributions.MultivariateNormalDiag(tf.zeros_like(z_mean),\n",
    "                                                            tf.ones_like(z_scale))\n",
    "\n",
    "    with tf.variable_scope('Decoder'):\n",
    "        input_proj = tf.layers.Dense(PARAMS['cnn_size'], tf.nn.relu)\n",
    "        # One latent sample per sentence, repeated at every one of the max_len steps.\n",
    "        z = tf.tile(tf.expand_dims(posterior.sample(), 1), [1, PARAMS['max_len'], 1])\n",
    "\n",
    "        if is_training:\n",
    "            dec_inp = tf.nn.embedding_lookup(embedding, labels['dec_inp'])\n",
    "            x = input_proj(tf.concat([dec_inp, z], -1))\n",
    "            logits = cnn_forward(x, embedding)\n",
    "            return logits, posterior, prior\n",
    "        return autoregressive(embedding, z, input_proj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn(features, labels, mode):\n",
    "    \"\"\"tf.estimator model_fn for the VAE (TRAIN and PREDICT modes only).\n",
    "\n",
    "    PREDICT returns the greedily decoded ids as predictions. TRAIN minimises the\n",
    "    summed token NLL of labels['dec_out'] plus the KL term weighted by kl_w_fn.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for any other mode (e.g. EVAL); no EstimatorSpec is defined for it.\n",
    "    \"\"\"\n",
    "    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.PREDICT):\n",
    "        raise ValueError('Unsupported mode: {}'.format(mode))\n",
    "\n",
    "    logits_or_ids = forward(features, labels, mode)\n",
    "\n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        return tf.estimator.EstimatorSpec(mode, predictions=logits_or_ids)\n",
    "\n",
    "    # TRAIN\n",
    "    logits, posterior, prior = logits_or_ids\n",
    "    global_step = tf.train.get_global_step()\n",
    "\n",
    "    # Reconstruction term: negative log-likelihood of the targets, summed over\n",
    "    # batch and time steps (padding positions included).\n",
    "    out_dist = tf.distributions.Categorical(logits)\n",
    "    nll_loss = - tf.reduce_sum(out_dist.log_prob(labels['dec_out']))\n",
    "\n",
    "    # KL(posterior || prior), weighted by the annealing schedule.\n",
    "    kl_w = kl_w_fn(global_step)\n",
    "    kl_loss = tf.reduce_sum(tf.distributions.kl_divergence(posterior, prior))\n",
    "\n",
    "    loss_op = nll_loss + kl_w * kl_loss\n",
    "\n",
    "    train_op = tf.train.AdamOptimizer().apply_gradients(\n",
    "        clip_grads(loss_op),\n",
    "        global_step=global_step)\n",
    "\n",
    "    lth = tf.train.LoggingTensorHook(\n",
    "        {'nll_loss': nll_loss, 'kl_w': kl_w, 'kl_loss': kl_loss}, every_n_iter=100)\n",
    "\n",
    "    return tf.estimator.EstimatorSpec(\n",
    "        mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl\n",
      "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11b432eb8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 17757.797, step = 1\n",
      "INFO:tensorflow:nll_loss = 17684.46, kl_w = 0.006692851, kl_loss = 10957.21\n",
      "INFO:tensorflow:global_step/sec: 2.17765\n",
      "INFO:tensorflow:loss = 10933.477, step = 101 (45.922 sec)\n",
      "INFO:tensorflow:nll_loss = 10900.135, kl_w = 0.007897083, kl_loss = 4222.0977 (45.922 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.26024\n",
      "INFO:tensorflow:loss = 10842.261, step = 201 (44.243 sec)\n",
      "INFO:tensorflow:nll_loss = 10804.835, kl_w = 0.009315956, kl_loss = 4017.3525 (44.243 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.20411\n",
      "INFO:tensorflow:loss = 10243.188, step = 301 (45.370 sec)\n",
      "INFO:tensorflow:nll_loss = 10192.387, kl_w = 0.010986943, kl_loss = 4623.8013 (45.370 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.18017\n",
      "INFO:tensorflow:loss = 9709.822, step = 401 (45.868 sec)\n",
      "INFO:tensorflow:nll_loss = 9650.805, kl_w = 0.012953726, kl_loss = 4556.036 (45.868 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.23163\n",
      "INFO:tensorflow:loss = 9265.691, step = 501 (44.810 sec)\n",
      "INFO:tensorflow:nll_loss = 9190.885, kl_w = 0.015267149, kl_loss = 4899.8467 (44.810 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.18961\n",
      "INFO:tensorflow:loss = 8816.474, step = 601 (45.670 sec)\n",
      "INFO:tensorflow:nll_loss = 8726.945, kl_w = 0.01798621, kl_loss = 4977.6035 (45.670 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.33169\n",
      "INFO:tensorflow:loss = 8734.211, step = 701 (42.887 sec)\n",
      "INFO:tensorflow:nll_loss = 8631.089, kl_w = 0.021179108, kl_loss = 4869.068 (42.888 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 782 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2049.7031.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-782\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i saw this movie for this movie i think that the <unk> <unk> i <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is a fan of time i just seen it movie about <unk> <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-782\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 783 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 9291.661, step = 783\n",
      "INFO:tensorflow:nll_loss = 9169.796, kl_w = 0.024205629, kl_loss = 5034.592\n",
      "INFO:tensorflow:global_step/sec: 2.23615\n",
      "INFO:tensorflow:loss = 8433.8125, step = 883 (44.722 sec)\n",
      "INFO:tensorflow:nll_loss = 8289.746, kl_w = 0.028470589, kl_loss = 5060.168 (44.721 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.24842\n",
      "INFO:tensorflow:loss = 8156.293, step = 983 (44.475 sec)\n",
      "INFO:tensorflow:nll_loss = 7992.166, kl_w = 0.033461247, kl_loss = 4904.9897 (44.475 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.19425\n",
      "INFO:tensorflow:loss = 8059.392, step = 1083 (45.574 sec)\n",
      "INFO:tensorflow:nll_loss = 7862.749, kl_w = 0.039291352, kl_loss = 5004.737 (45.574 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.14054\n",
      "INFO:tensorflow:loss = 7877.4155, step = 1183 (46.717 sec)\n",
      "INFO:tensorflow:nll_loss = 7650.47, kl_w = 0.04608883, kl_loss = 4924.08 (46.717 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.13802\n",
      "INFO:tensorflow:loss = 8841.221, step = 1283 (46.772 sec)\n",
      "INFO:tensorflow:nll_loss = 8583.125, kl_w = 0.053996176, kl_loss = 4779.885 (46.772 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.04279\n",
      "INFO:tensorflow:loss = 8351.722, step = 1383 (48.953 sec)\n",
      "INFO:tensorflow:nll_loss = 8054.1445, kl_w = 0.06317033, kl_loss = 4710.71 (48.953 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.18052\n",
      "INFO:tensorflow:loss = 8543.8, step = 1483 (45.860 sec)\n",
      "INFO:tensorflow:nll_loss = 8207.943, kl_w = 0.07378165, kl_loss = 4552.0293 (45.860 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1564 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2099.5452.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-1564\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i was so bad and this movie is that all and all of time <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is a lot of movies that don't waste of this movie <unk> <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-1564\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1565 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 8396.324, step = 1565\n",
      "INFO:tensorflow:nll_loss = 8022.4976, kl_w = 0.08368247, kl_loss = 4467.2085\n",
      "INFO:tensorflow:global_step/sec: 2.21291\n",
      "INFO:tensorflow:loss = 8478.651, step = 1665 (45.190 sec)\n",
      "INFO:tensorflow:nll_loss = 8043.9316, kl_w = 0.09738124, kl_loss = 4464.1055 (45.190 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.07123\n",
      "INFO:tensorflow:loss = 8133.461, step = 1765 (48.281 sec)\n",
      "INFO:tensorflow:nll_loss = 7641.1675, kl_w = 0.11304584, kl_loss = 4354.8145 (48.281 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.21284\n",
      "INFO:tensorflow:loss = 7868.1704, step = 1865 (45.191 sec)\n",
      "INFO:tensorflow:nll_loss = 7321.884, kl_w = 0.13086486, kl_loss = 4174.4326 (45.191 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.02446\n",
      "INFO:tensorflow:loss = 8176.475, step = 1965 (49.396 sec)\n",
      "INFO:tensorflow:nll_loss = 7531.4424, kl_w = 0.15101443, kl_loss = 4271.332 (49.395 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.13666\n",
      "INFO:tensorflow:loss = 7860.392, step = 2065 (46.802 sec)\n",
      "INFO:tensorflow:nll_loss = 7156.92, kl_w = 0.17364664, kl_loss = 4051.1694 (46.802 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.22284\n",
      "INFO:tensorflow:loss = 8611.154, step = 2165 (44.987 sec)\n",
      "INFO:tensorflow:nll_loss = 7857.939, kl_w = 0.19887616, kl_loss = 3787.358 (44.987 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.18977\n",
      "INFO:tensorflow:loss = 8461.563, step = 2265 (45.667 sec)\n",
      "INFO:tensorflow:nll_loss = 7666.117, kl_w = 0.22676536, kl_loss = 3507.7935 (45.667 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 2346 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2221.3452.\n",
      "INFO:tensorflow:Calling model_fn.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-2346\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i was a fan of time i was a long time at all it <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is a waste of time i recommend it of the same time <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-2346\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 2347 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 8440.029, step = 2347\n",
      "INFO:tensorflow:nll_loss = 7583.853, kl_w = 0.25161827, kl_loss = 3402.6782\n",
      "INFO:tensorflow:global_step/sec: 2.24946\n",
      "INFO:tensorflow:loss = 8202.8545, step = 2447 (44.456 sec)\n",
      "INFO:tensorflow:nll_loss = 7316.755, kl_w = 0.2842792, kl_loss = 3117.0063 (44.456 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.24717\n",
      "INFO:tensorflow:loss = 8300.074, step = 2547 (44.501 sec)\n",
      "INFO:tensorflow:nll_loss = 7341.061, kl_w = 0.31937042, kl_loss = 3002.8237 (44.500 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.12767\n",
      "INFO:tensorflow:loss = 8463.965, step = 2647 (47.001 sec)\n",
      "INFO:tensorflow:nll_loss = 7400.3125, kl_w = 0.35663486, kl_loss = 2982.4695 (47.001 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.03059\n",
      "INFO:tensorflow:loss = 9004.262, step = 2747 (49.246 sec)\n",
      "INFO:tensorflow:nll_loss = 7917.4688, kl_w = 0.39571938, kl_loss = 2746.374 (49.246 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.01268\n",
      "INFO:tensorflow:loss = 9210.676, step = 2847 (49.685 sec)\n",
      "INFO:tensorflow:nll_loss = 8192.643, kl_w = 0.43618327, kl_loss = 2333.9585 (49.685 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.15748\n",
      "INFO:tensorflow:loss = 8884.15, step = 2947 (46.351 sec)\n",
      "INFO:tensorflow:nll_loss = 7810.6533, kl_w = 0.47751516, kl_loss = 2248.0908 (46.351 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.08138\n",
      "INFO:tensorflow:loss = 9062.841, step = 3047 (48.045 sec)\n",
      "INFO:tensorflow:nll_loss = 7976.9478, kl_w = 0.5191573, kl_loss = 2091.645 (48.045 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 3128 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2154.5928.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-3128\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i saw this movie and the first time i had to be <unk> in <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is a lot of movies or even on the people br br <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-3128\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 3129 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 8548.301, step = 3129\n",
      "INFO:tensorflow:nll_loss = 7406.037, kl_w = 0.55313194, kl_loss = 2065.0845\n",
      "INFO:tensorflow:global_step/sec: 2.26606\n",
      "INFO:tensorflow:loss = 8541.703, step = 3229 (44.131 sec)\n",
      "INFO:tensorflow:nll_loss = 7347.823, kl_w = 0.5938731, kl_loss = 2010.3286 (44.130 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.24079\n",
      "INFO:tensorflow:loss = 8523.439, step = 3329 (44.627 sec)\n",
      "INFO:tensorflow:nll_loss = 7305.766, kl_w = 0.6333619, kl_loss = 1922.5557 (44.627 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.2027\n",
      "INFO:tensorflow:loss = 8285.988, step = 3429 (45.399 sec)\n",
      "INFO:tensorflow:nll_loss = 7008.6104, kl_w = 0.6711373, kl_loss = 1903.3037 (45.399 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.0039\n",
      "INFO:tensorflow:loss = 8297.673, step = 3529 (49.903 sec)\n",
      "INFO:tensorflow:nll_loss = 7074.6333, kl_w = 0.7068222, kl_loss = 1730.3353 (49.903 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.04945\n",
      "INFO:tensorflow:loss = 8950.645, step = 3629 (48.794 sec)\n",
      "INFO:tensorflow:nll_loss = 7741.7915, kl_w = 0.7401343, kl_loss = 1633.2881 (48.794 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.13941\n",
      "INFO:tensorflow:loss = 9194.217, step = 3729 (46.742 sec)\n",
      "INFO:tensorflow:nll_loss = 8051.4688, kl_w = 0.7708882, kl_loss = 1482.3782 (46.741 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.03637\n",
      "INFO:tensorflow:loss = 8771.532, step = 3829 (49.107 sec)\n",
      "INFO:tensorflow:nll_loss = 7677.521, kl_w = 0.79899096, kl_loss = 1369.2411 (49.107 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 3910 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2217.6265.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-3910\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i have a fan of the time i give it a 1 10 10 <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is a waste of time at least it is just worth watching <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-3910\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 3911 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 8917.375, step = 3911\n",
      "INFO:tensorflow:nll_loss = 7857.8936, kl_w = 0.82004714, kl_loss = 1291.9761\n",
      "INFO:tensorflow:global_step/sec: 2.20433\n",
      "INFO:tensorflow:loss = 8533.211, step = 4011 (45.367 sec)\n",
      "INFO:tensorflow:nll_loss = 7465.018, kl_w = 0.84334546, kl_loss = 1266.614 (45.367 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.09661\n",
      "INFO:tensorflow:loss = 8578.28, step = 4111 (47.697 sec)\n",
      "INFO:tensorflow:nll_loss = 7569.371, kl_w = 0.8641271, kl_loss = 1167.5469 (47.697 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.0179\n",
      "INFO:tensorflow:loss = 9465.794, step = 4211 (49.556 sec)\n",
      "INFO:tensorflow:nll_loss = 8467.34, kl_w = 0.88253593, kl_loss = 1131.3463 (49.556 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.17251\n",
      "INFO:tensorflow:loss = 9026.106, step = 4311 (46.030 sec)\n",
      "INFO:tensorflow:nll_loss = 8121.599, kl_w = 0.89874285, kl_loss = 1006.4143 (46.030 sec)\n",
      "INFO:tensorflow:global_step/sec: 1.99471\n",
      "INFO:tensorflow:loss = 9163.89, step = 4411 (50.132 sec)\n",
      "INFO:tensorflow:nll_loss = 8329.456, kl_w = 0.9129343, kl_loss = 914.01294 (50.132 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.02579\n",
      "INFO:tensorflow:loss = 8948.267, step = 4511 (49.364 sec)\n",
      "INFO:tensorflow:nll_loss = 8179.454, kl_w = 0.92530197, kl_loss = 830.87744 (49.364 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.22869\n",
      "INFO:tensorflow:loss = 8946.656, step = 4611 (44.869 sec)\n",
      "INFO:tensorflow:nll_loss = 8206.183, kl_w = 0.93603605, kl_loss = 791.07336 (44.869 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 4692 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2038.3258.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-4692\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i saw this movie i was a good movie i have been a big <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is really awful but there are one of the movies i've seen <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-4692\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 4693 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 9238.526, step = 4693\n",
      "INFO:tensorflow:nll_loss = 8553.328, kl_w = 0.94374704, kl_loss = 726.0397\n",
      "INFO:tensorflow:global_step/sec: 2.25378\n",
      "INFO:tensorflow:loss = 8580.965, step = 4793 (44.371 sec)\n",
      "INFO:tensorflow:nll_loss = 7952.576, kl_w = 0.95196813, kl_loss = 660.09454 (44.371 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.2761\n",
      "INFO:tensorflow:loss = 8895.101, step = 4893 (43.935 sec)\n",
      "INFO:tensorflow:nll_loss = 8320.993, kl_w = 0.9590399, kl_loss = 598.62756 (43.935 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.2459\n",
      "INFO:tensorflow:loss = 8403.889, step = 4993 (44.526 sec)\n",
      "INFO:tensorflow:nll_loss = 7822.6953, kl_w = 0.9651086, kl_loss = 602.2049 (44.526 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.13543\n",
      "INFO:tensorflow:loss = 8460.121, step = 5093 (46.829 sec)\n",
      "INFO:tensorflow:nll_loss = 7891.7383, kl_w = 0.97030604, kl_loss = 585.77716 (46.829 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.1263\n",
      "INFO:tensorflow:loss = 8444.383, step = 5193 (47.030 sec)\n",
      "INFO:tensorflow:nll_loss = 7946.053, kl_w = 0.97474945, kl_loss = 511.2388 (47.030 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.12127\n",
      "INFO:tensorflow:loss = 8076.477, step = 5293 (47.142 sec)\n",
      "INFO:tensorflow:nll_loss = 7626.3735, kl_w = 0.9785427, kl_loss = 459.97333 (47.142 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.15419\n",
      "INFO:tensorflow:loss = 8949.117, step = 5393 (46.421 sec)\n",
      "INFO:tensorflow:nll_loss = 8534.508, kl_w = 0.9817768, kl_loss = 422.30548 (46.421 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 5474 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2357.3142.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-5474\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i think it is a bit of the best movies of the movie ever <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is not worth a watch for the movie that it is worth <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-5474\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 5475 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 8640.12, step = 5475\n",
      "INFO:tensorflow:nll_loss = 8271.754, kl_w = 0.98406756, kl_loss = 374.33014\n",
      "INFO:tensorflow:global_step/sec: 2.30633\n",
      "INFO:tensorflow:loss = 8474.056, step = 5575 (43.360 sec)\n",
      "INFO:tensorflow:nll_loss = 8105.4375, kl_w = 0.9864804, kl_loss = 373.67014 (43.360 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.25803\n",
      "INFO:tensorflow:loss = 8225.799, step = 5675 (44.286 sec)\n",
      "INFO:tensorflow:nll_loss = 7858.8994, kl_w = 0.98853207, kl_loss = 371.15564 (44.286 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.21343\n",
      "INFO:tensorflow:loss = 8198.463, step = 5775 (45.179 sec)\n",
      "INFO:tensorflow:nll_loss = 7868.3145, kl_w = 0.9902755, kl_loss = 333.39062 (45.179 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.1621\n",
      "INFO:tensorflow:loss = 8170.575, step = 5875 (46.251 sec)\n",
      "INFO:tensorflow:nll_loss = 7861.3584, kl_w = 0.9917561, kl_loss = 311.78693 (46.251 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.1055\n",
      "INFO:tensorflow:loss = 8702.276, step = 5975 (47.495 sec)\n",
      "INFO:tensorflow:nll_loss = 8425.444, kl_w = 0.99301285, kl_loss = 278.77972 (47.495 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.02406\n",
      "INFO:tensorflow:loss = 8765.69, step = 6075 (49.406 sec)\n",
      "INFO:tensorflow:nll_loss = 8521.092, kl_w = 0.9940791, kl_loss = 246.05542 (49.406 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.10132\n",
      "INFO:tensorflow:loss = 8707.773, step = 6175 (47.589 sec)\n",
      "INFO:tensorflow:nll_loss = 8495.986, kl_w = 0.99498355, kl_loss = 212.8547 (47.589 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 6256 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2126.1506.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-6256\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i would recommend this movie to anyone who likes to see this movie again <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this is a great film and i hope to see it br br <unk> <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-6256\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 6257 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 8432.456, step = 6257\n",
      "INFO:tensorflow:nll_loss = 8218.91, kl_w = 0.9956215, kl_loss = 214.48544\n",
      "INFO:tensorflow:global_step/sec: 2.09656\n",
      "INFO:tensorflow:loss = 8413.566, step = 6357 (47.698 sec)\n",
      "INFO:tensorflow:nll_loss = 8218.102, kl_w = 0.9962913, kl_loss = 196.19244 (47.698 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.13854\n",
      "INFO:tensorflow:loss = 8299.255, step = 6457 (46.761 sec)\n",
      "INFO:tensorflow:nll_loss = 8114.4907, kl_w = 0.99685884, kl_loss = 185.34625 (46.761 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.18123\n",
      "INFO:tensorflow:loss = 8365.56, step = 6557 (45.845 sec)\n",
      "INFO:tensorflow:nll_loss = 8157.5864, kl_w = 0.9973398, kl_loss = 208.5275 (45.846 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.14873\n",
      "INFO:tensorflow:loss = 8372.05, step = 6657 (46.539 sec)\n",
      "INFO:tensorflow:nll_loss = 8173.4062, kl_w = 0.99774724, kl_loss = 199.09207 (46.539 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.12001\n",
      "INFO:tensorflow:loss = 8207.27, step = 6757 (47.170 sec)\n",
      "INFO:tensorflow:nll_loss = 8013.0303, kl_w = 0.99809235, kl_loss = 194.61057 (47.170 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.02291\n",
      "INFO:tensorflow:loss = 8114.836, step = 6857 (49.434 sec)\n",
      "INFO:tensorflow:nll_loss = 7946.503, kl_w = 0.99838483, kl_loss = 168.60544 (49.434 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.17923\n",
      "INFO:tensorflow:loss = 8415.801, step = 6957 (45.888 sec)\n",
      "INFO:tensorflow:nll_loss = 8276.776, kl_w = 0.9986324, kl_loss = 139.21448 (45.888 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 7038 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2180.5276.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-7038\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i have to say that it was a good movie to see it again <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: i think that i was a fan of the <unk> of the movie ever <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-7038\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 7039 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:loss = 8059.9185, step = 7039\n",
      "INFO:tensorflow:nll_loss = 7915.2476, kl_w = 0.99880695, kl_loss = 144.8439\n",
      "INFO:tensorflow:global_step/sec: 2.00921\n",
      "INFO:tensorflow:loss = 8049.7905, step = 7139 (49.772 sec)\n",
      "INFO:tensorflow:nll_loss = 7920.04, kl_w = 0.9989899, kl_loss = 129.88147 (49.772 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.22952\n",
      "INFO:tensorflow:loss = 8429.317, step = 7239 (44.853 sec)\n",
      "INFO:tensorflow:nll_loss = 8308.171, kl_w = 0.9991448, kl_loss = 121.249725 (44.852 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.12369\n",
      "INFO:tensorflow:loss = 8471.212, step = 7339 (47.088 sec)\n",
      "INFO:tensorflow:nll_loss = 8349.805, kl_w = 0.999276, kl_loss = 121.49566 (47.088 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.15853\n",
      "INFO:tensorflow:loss = 8364.232, step = 7439 (46.328 sec)\n",
      "INFO:tensorflow:nll_loss = 8263.13, kl_w = 0.999387, kl_loss = 101.16458 (46.328 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.22185\n",
      "INFO:tensorflow:loss = 8160.001, step = 7539 (45.008 sec)\n",
      "INFO:tensorflow:nll_loss = 8048.138, kl_w = 0.99948114, kl_loss = 111.921104 (45.008 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.22178\n",
      "INFO:tensorflow:loss = 7860.17, step = 7639 (45.009 sec)\n",
      "INFO:tensorflow:nll_loss = 7729.9297, kl_w = 0.9995608, kl_loss = 130.29723 (45.009 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.15855\n",
      "INFO:tensorflow:loss = 8154.5186, step = 7739 (46.327 sec)\n",
      "INFO:tensorflow:nll_loss = 8039.818, kl_w = 0.9996282, kl_loss = 114.74347 (46.327 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 7820 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 2073.8198.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpt6rxeanl/model.ckpt-7820\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i saw this movie when i was a kid i was disappointed when i <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: i have to say that this is a good movie i have ever seen <end>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def inf_inp(test_strs):\n",
    "    \"\"\"Encode raw sentences as a padded id matrix for `estimator.predict`.\n",
    "\n",
    "    Args:\n",
    "        test_strs: iterable of sentences; each is tokenised with str.split().\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (len(test_strs), max_len), padded/truncated at the end.\n",
    "\n",
    "    Words missing from the IMDB index, and words whose id is >= vocab_size,\n",
    "    become <unk>. `load_data(num_words=vocab_size)` applies the same cut to\n",
    "    the training data, so larger ids were never seen during training and\n",
    "    would index past a vocab_size-row embedding (assumed; model_fn is\n",
    "    defined elsewhere).\n",
    "    \"\"\"\n",
    "    word2idx = PARAMS['word2idx']\n",
    "    unk_id = word2idx['<unk>']\n",
    "\n",
    "    def to_id(word):\n",
    "        idx = word2idx.get(word, unk_id)\n",
    "        return idx if idx < PARAMS['vocab_size'] else unk_id\n",
    "\n",
    "    x = [[to_id(w) for w in s.split()] for s in test_strs]\n",
    "    return tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        x, PARAMS['max_len'], truncating='post', padding='post')\n",
    "\n",
    "def demo(test_strs, pred_ids):\n",
    "    \"\"\"Print each input sentence followed by the model's reconstruction.\n",
    "\n",
    "    `test_strs` and `pred_ids` are walked in lockstep (zip stops at the\n",
    "    shorter one); ids absent from `idx2word` are rendered as '<unk>'.\n",
    "    \"\"\"\n",
    "    for sentence, ids in zip(test_strs, pred_ids):\n",
    "        print('\\nOriginal:', sentence)\n",
    "        idx2word = PARAMS['idx2word']\n",
    "        tokens = [idx2word.get(i, '<unk>') for i in ids]\n",
    "        print('Reconstr:', ' '.join(tokens))\n",
    "\n",
    "\n",
    "test_strs = [\n",
    "    'i love this film and i think it is one of the best films',\n",
    "    'this movie is a waste of time and there is no point to watch it',\n",
    "]\n",
    "\n",
    "estimator = tf.estimator.Estimator(model_fn)\n",
    "\n",
    "# Both input_fns are loop-invariant, so build them once. With the default\n",
    "# num_epochs=1, every train() call below makes one full pass over the data.\n",
    "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
    "    x=enc_inp,\n",
    "    y={'dec_inp': dec_inp, 'dec_out': dec_out},\n",
    "    batch_size=PARAMS['batch_size'],\n",
    "    shuffle=True)\n",
    "predict_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
    "    x=inf_inp(test_strs),\n",
    "    shuffle=False)\n",
    "\n",
    "# Alternate one training epoch with a reconstruction demo on test_strs.\n",
    "for _ in range(PARAMS['n_epochs']):\n",
    "    estimator.train(train_input_fn)\n",
    "    demo(test_strs, list(estimator.predict(predict_input_fn)))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
